#!/usr/bin python3
""" The script to run the extract process of faceswap """

import logging
import os
import sys
from pathlib import Path

from tqdm import tqdm

from lib.faces_detect import DetectedFace
from lib.gpu_stats import GPUStats
from lib.multithreading import MultiThread, PoolProcess, SpawnProcess
from lib.queue_manager import queue_manager, QueueEmpty
from lib.utils import get_folder, hash_encode_image
from plugins.plugin_loader import PluginLoader
from scripts.fsmedia import Alignments, Images, PostProcess, Utils

tqdm.monitor_interval = 0  # workaround for TqdmSynchronisationWarning
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


class Extract():
    """ The extract process. """
    def __init__(self, arguments):
        logger.debug("Initializing %s: (args: %s", self.__class__.__name__, arguments)
        self.args = arguments
        self.output_dir = get_folder(self.args.output_dir)
        logger.info("Output Directory: %s", self.args.output_dir)
        self.images = Images(self.args)
        self.alignments = Alignments(self.args, True, self.images.is_video)
        self.plugins = Plugins(self.args)

        self.post_process = PostProcess(arguments)

        self.verify_output = False
        self.save_interval = None
        if hasattr(self.args, "save_interval"):
            self.save_interval = self.args.save_interval
        logger.debug("Initialized %s", self.__class__.__name__)

    def process(self):
        """ Perform the extraction process """
        logger.info('Starting, this may take a while...')
        Utils.set_verbosity(self.args.loglevel)
#        queue_manager.debug_monitor(1)
        self.threaded_io("load")
        save_thread = self.threaded_io("save")
        self.run_extraction()
        save_thread.join()
        self.alignments.save()
        Utils.finalize(self.images.images_found,
                       self.alignments.faces_count,
                       self.verify_output)

    def threaded_io(self, task, io_args=None):
        """ Perform I/O task in a background thread """
        logger.debug("Threading task: (Task: '%s')", task)
        io_args = tuple() if io_args is None else (io_args, )
        if task == "load":
            func = self.load_images
        elif task == "save":
            func = self.save_faces
        elif task == "reload":
            func = self.reload_images
        io_thread = MultiThread(func, *io_args, thread_count=1)
        io_thread.start()
        return io_thread

    def load_images(self):
        """ Load the images """
        logger.debug("Load Images: Start")
        load_queue = queue_manager.get_queue("load")
        for filename, image in self.images.load():
            if load_queue.shutdown.is_set():
                logger.debug("Load Queue: Stop signal received. Terminating")
                break
            if image is None or not image.any():
                logger.warning("Unable to open image. Skipping: '%s'", filename)
                continue
            imagename = os.path.basename(filename)
            if imagename in self.alignments.data.keys():
                logger.trace("Skipping image: '%s'", filename)
                continue
            item = {"filename": filename,
                    "image": image}
            load_queue.put(item)
        load_queue.put("EOF")
        logger.debug("Load Images: Complete")

    def reload_images(self, detected_faces):
        """ Reload the images and pair to detected face """
        logger.debug("Reload Images: Start. Detected Faces Count: %s", len(detected_faces))
        load_queue = queue_manager.get_queue("detect")
        for filename, image in self.images.load():
            if load_queue.shutdown.is_set():
                logger.debug("Reload Queue: Stop signal received. Terminating")
                break
            logger.trace("Reloading image: '%s'", filename)
            detect_item = detected_faces.pop(filename, None)
            if not detect_item:
                logger.warning("Couldn't find faces for: %s", filename)
                continue
            detect_item["image"] = image
            load_queue.put(detect_item)
        load_queue.put("EOF")
        logger.debug("Reload Images: Complete")

    @staticmethod
    def save_faces():
        """ Save the generated faces """
        logger.debug("Save Faces: Start")
        save_queue = queue_manager.get_queue("save")
        while True:
            if save_queue.shutdown.is_set():
                logger.debug("Save Queue: Stop signal received. Terminating")
                break
            item = save_queue.get()
            if item == "EOF":
                break
            filename, face = item

            logger.trace("Saving face: '%s'", filename)
            try:
                with open(filename, "wb") as out_file:
                    out_file.write(face)
            except Exception as err:  # pylint: disable=broad-except
                logger.error("Failed to save image '%s'. Original Error: %s", filename, err)
                continue
        logger.debug("Save Faces: Complete")

    def run_extraction(self):
        """ Run Face Detection """
        save_queue = queue_manager.get_queue("save")
        to_process = self.process_item_count()
        frame_no = 0
        size = self.args.size if hasattr(self.args, "size") else 256
        align_eyes = self.args.align_eyes if hasattr(self.args, "align_eyes") else False

        if self.plugins.is_parallel:
            logger.debug("Using parallel processing")
            self.plugins.launch_aligner()
            self.plugins.launch_detector()
        if not self.plugins.is_parallel:
            logger.debug("Using serial processing")
            self.run_detection(to_process)
            self.plugins.launch_aligner()

        for faces in tqdm(self.plugins.detect_faces(extract_pass="align"),
                          total=to_process,
                          file=sys.stdout,
                          desc="Extracting faces"):

            filename = faces["filename"]

            self.align_face(faces, align_eyes, size, filename)
            self.post_process.do_actions(faces)

            faces_count = len(faces["detected_faces"])
            if faces_count == 0:
                logger.verbose("No faces were detected in image: %s",
                               os.path.basename(filename))

            if not self.verify_output and faces_count > 1:
                self.verify_output = True

            self.output_faces(filename, faces, save_queue)

            frame_no += 1
            if frame_no == self.save_interval:
                self.alignments.save()
                frame_no = 0

        save_queue.put("EOF")

    def process_item_count(self):
        """ Return the number of items to be processedd """
        processed = sum(os.path.basename(frame) in self.alignments.data.keys()
                        for frame in self.images.input_images)
        logger.debug("Items already processed: %s", processed)

        if processed != 0 and self.args.skip_existing:
            logger.info("Skipping previously extracted frames: %s", processed)
        if processed != 0 and self.args.skip_faces:
            logger.info("Skipping frames with detected faces: %s", processed)

        to_process = self.images.images_found - processed
        logger.debug("Items to be Processed: %s", to_process)
        if to_process == 0:
            logger.error("No frames to process. Exiting")
            queue_manager.terminate_queues()
            exit(0)
        return to_process

    def run_detection(self, to_process):
        """ Run detection only """
        self.plugins.launch_detector()
        detected_faces = dict()
        for detected in tqdm(self.plugins.detect_faces(extract_pass="detect"),
                             total=to_process,
                             file=sys.stdout,
                             desc="Detecting faces"):
            exception = detected.get("exception", False)
            if exception:
                break

            del detected["image"]
            filename = detected["filename"]

            detected_faces[filename] = detected

        self.threaded_io("reload", detected_faces)

    def align_face(self, faces, align_eyes, size, filename):
        """ Align the detected face and add the destination file path """
        final_faces = list()
        image = faces["image"]
        landmarks = faces["landmarks"]
        detected_faces = faces["detected_faces"]
        for idx, face in enumerate(detected_faces):
            detected_face = DetectedFace()
            detected_face.from_dlib_rect(face, image)
            detected_face.landmarksXY = landmarks[idx]
            detected_face.load_aligned(image, size=size, align_eyes=align_eyes)
            final_faces.append({"file_location": self.output_dir / Path(filename).stem,
                                "face": detected_face})
        faces["detected_faces"] = final_faces

    def output_faces(self, filename, faces, save_queue):
        """ Output faces to save thread """
        final_faces = list()
        for idx, detected_face in enumerate(faces["detected_faces"]):
            output_file = detected_face["file_location"]
            extension = Path(filename).suffix
            out_filename = "{}_{}{}".format(str(output_file), str(idx), extension)

            face = detected_face["face"]
            resized_face = face.aligned_face

            face.hash, img = hash_encode_image(resized_face, extension)
            save_queue.put((out_filename, img))
            final_faces.append(face.to_alignment())
        self.alignments.data[os.path.basename(filename)] = final_faces


class Plugins():
    """ Detector and Aligner Plugins and queues """
    def __init__(self, arguments):
        logger.debug("Initializing %s", self.__class__.__name__)
        self.args = arguments
        self.detector = self.load_detector()
        self.aligner = self.load_aligner()
        self.is_parallel = self.set_parallel_processing()

        self.process_detect = None
        self.process_align = None
        self.add_queues()
        logger.debug("Initialized %s", self.__class__.__name__)

    def set_parallel_processing(self):
        """ Set whether to run detect and align together or separately """
        detector_vram = self.detector.vram
        aligner_vram = self.aligner.vram
        gpu_stats = GPUStats()
        if (detector_vram == 0
                or aligner_vram == 0
                or gpu_stats.device_count == 0):
            logger.debug("At least one of aligner or detector have no VRAM requirement. "
                         "Enabling parallel processing.")
            return True

        if hasattr(self.args, "multiprocess") and not self.args.multiprocess:
            logger.info("NB: Parallel processing disabled.You may get faster "
                        "extraction speeds by enabling it with the -mp switch")
            return False

        required_vram = detector_vram + aligner_vram + 320  # 320MB buffer
        stats = gpu_stats.get_card_most_free()
        free_vram = int(stats["free"])
        logger.verbose("%s - %sMB free of %sMB",
                       stats["device"],
                       free_vram,
                       int(stats["total"]))
        if free_vram <= required_vram:
            logger.warning("Not enough free VRAM for parallel processing. "
                           "Switching to serial")
            return False
        return True

    def add_queues(self):
        """ Add the required processing queues to Queue Manager """
        for task in ("load", "detect", "align", "save"):
            size = 0
            if task == "load" or (not self.is_parallel and task == "detect"):
                size = 100
            queue_manager.add_queue(task, maxsize=size)

    def load_detector(self):
        """ Set global arguments and load detector plugin """
        detector_name = self.args.detector.replace("-", "_").lower()
        logger.debug("Loading Detector: '%s'", detector_name)
        # Rotation
        rotation = None
        if hasattr(self.args, "rotate_images"):
            rotation = self.args.rotate_images

        detector = PluginLoader.get_detector(detector_name)(
            loglevel=self.args.loglevel,
            rotation=rotation)

        return detector

    def load_aligner(self):
        """ Set global arguments and load aligner plugin """
        aligner_name = self.args.aligner.replace("-", "_").lower()
        logger.debug("Loading Aligner: '%s'", aligner_name)

        aligner = PluginLoader.get_aligner(aligner_name)(
            loglevel=self.args.loglevel)

        return aligner

    def launch_aligner(self):
        """ Launch the face aligner """
        logger.debug("Launching Aligner")
        out_queue = queue_manager.get_queue("align")
        kwargs = {"in_queue": queue_manager.get_queue("detect"),
                  "out_queue": out_queue}

        self.process_align = SpawnProcess(self.aligner.run, **kwargs)
        event = self.process_align.event
        self.process_align.start()

        # Wait for Aligner to take it's VRAM
        # The first ever load of the model for FAN has reportedly taken
        # up to 3-4 minutes, hence high timeout.
        # TODO investigate why this is and fix if possible
        for mins in reversed(range(5)):
            event.wait(60)
            if event.is_set():
                break
            if mins == 0:
                raise ValueError("Error initializing Aligner")
            logger.info("Waiting for Aligner... Time out in %s minutes", mins)

        logger.debug("Launched Aligner")

    def launch_detector(self):
        """ Launch the face detector """
        logger.debug("Launching Detector")
        out_queue = queue_manager.get_queue("detect")
        kwargs = {"in_queue": queue_manager.get_queue("load"),
                  "out_queue": out_queue}

        mp_func = PoolProcess if self.detector.parent_is_pool else SpawnProcess
        self.process_detect = mp_func(self.detector.run, **kwargs)

        event = None
        if hasattr(self.process_detect, "event"):
            event = self.process_detect.event

        self.process_detect.start()

        if event is None:
            logger.debug("Launched Detector")
            return

        for mins in reversed(range(5)):
            event.wait(60)
            if event.is_set():
                break
            if mins == 0:
                raise ValueError("Error initializing Detector")
            logger.info("Waiting for Detector... Time out in %s minutes", mins)

        logger.debug("Launched Detector")

    def detect_faces(self, extract_pass="detect"):
        """ Detect faces from in an image """
        logger.debug("Running Detection. Pass: '%s'", extract_pass)
        if self.is_parallel or extract_pass == "align":
            out_queue = queue_manager.get_queue("align")
        if not self.is_parallel and extract_pass == "detect":
            out_queue = queue_manager.get_queue("detect")

        while True:
            try:
                faces = out_queue.get(True, 1)
                if faces == "EOF":
                    break
                if isinstance(faces, dict) and faces.get("exception"):
                    pid = faces["exception"][0]
                    t_back = faces["exception"][1].getvalue()
                    err = "Error in child process {}. {}".format(pid, t_back)
                    raise Exception(err)
            except QueueEmpty:
                continue

            yield faces
        logger.debug("Detection Complete")
